{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.datasets import load_boston\n",
    "from sklearn.datasets import load_diabetes\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.datasets import load_linnerud\n",
    "from sklearn.datasets import load_wine\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.datasets import load_breast_cancer\n",
    "\n",
    "from keras.datasets import mnist\n",
    "from keras.datasets import cifar10\n",
    "\n",
    "from keras.utils import to_categorical\n",
    "from hyperactive import RandomSearchOptimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_sklearn = [\n",
    "    \"accuracy_score\",\n",
    "    \"balanced_accuracy_score\",\n",
    "    \"average_precision_score\",\n",
    "    \"f1_score\",\n",
    "    \"precision_score\",\n",
    "    \"recall_score\",\n",
    "    \"jaccard_score\",\n",
    "    \"roc_auc_score\",\n",
    "    \"explained_variance_score\",\n",
    "    \"r2_score\",\n",
    "    \"brier_score_loss\",\n",
    "    \"log_loss\",\n",
    "    \"max_error\",\n",
    "    \"mean_absolute_error\",\n",
    "    \"mean_squared_error\",\n",
    "    \"mean_squared_log_error\",\n",
    "    \"median_absolute_error\",\n",
    "]\n",
    "\n",
    "metrics_keras = [\n",
    "    \"accuracy\",\n",
    "    \"binary_accuracy\",\n",
    "    \"categorical_accuracy\",\n",
    "    \"sparse_categorical_accuracy\",\n",
    "    \"top_k_categorical_accuracy\",\n",
    "    \"sparse_top_k_categorical_accuracy\",\n",
    "    \"mean_squared_error\",\n",
    "    \"mean_absolute_error\",\n",
    "    \"mean_absolute_percentage_error\",\n",
    "    \"mean_squared_logarithmic_error\",\n",
    "    \"squared_hinge\",\n",
    "    \"hinge\",\n",
    "    \"categorical_hinge\",\n",
    "    \"logcosh\",\n",
    "    \"categorical_crossentropy\",\n",
    "    \"sparse_categorical_crossentropy\",\n",
    "    \"binary_crossentropy\",\n",
    "    \"kullback_leibler_divergence\",\n",
    "    \"poisson\",\n",
    "    \"cosine_proximity\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_sklearn_cla = {\n",
    "    \"sklearn.ensemble.RandomForestClassifier\": {\"n_estimators\": range(1, 10, 1)}\n",
    "}\n",
    "\n",
    "config_sklearn_reg = {\n",
    "    \"sklearn.ensemble.RandomForestRegressor\": {\"n_estimators\": range(1, 10, 1)}\n",
    "}\n",
    "\n",
    "config_keras_cla = {\n",
    "    \"keras.compile.0\": {\"loss\": [\"categorical_crossentropy\"], \"optimizer\": [\"adam\"]},\n",
    "    \"keras.fit.0\": {\"epochs\": [1], \"batch_size\": [500], \"verbose\": [0]},\n",
    "    \"keras.layers.Conv2D.1\": {\n",
    "        \"filters\": [32, 64, 128],\n",
    "        \"kernel_size\": range(3, 4),\n",
    "        \"activation\": [\"relu\"],\n",
    "    },\n",
    "    \"keras.layers.MaxPooling2D.2\": {\"pool_size\": [(2, 2)]},\n",
    "    \"keras.layers.Conv2D.3\": {\n",
    "        \"filters\": [16, 32, 64],\n",
    "        \"kernel_size\": [3],\n",
    "        \"activation\": [\"relu\"],\n",
    "    },\n",
    "    \"keras.layers.MaxPooling2D.4\": {\"pool_size\": [(2, 2)]},\n",
    "    \"keras.layers.Flatten.5\": {},\n",
    "    \"keras.layers.Dense.6\": {\"units\": range(30, 200, 10), \"activation\": [\"softmax\"]},\n",
    "    \"keras.layers.Dropout.7\": {\"rate\": list(np.arange(0.4, 0.8, 0.1))},\n",
    "    \"keras.layers.Dense.8\": {\"units\": [10], \"activation\": [\"softmax\"]},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_metrics(metrics, search_config, X, y):\n",
    "    metric_list = []\n",
    "    for metric in metrics:\n",
    "        opt = RandomSearchOptimizer(search_config, n_iter=0, metric=metric, verbosity=0)\n",
    "        try:\n",
    "            opt.fit(X, y)\n",
    "            metric_list.append(metric)\n",
    "        except Exception:\n",
    "            pass\n",
    "        \n",
    "    return metric_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Iris data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "target shape: (150,)\n",
      "working metrics:\n",
      "accuracy_score\n",
      "balanced_accuracy_score\n",
      "explained_variance_score\n",
      "r2_score\n",
      "max_error\n",
      "mean_absolute_error\n",
      "mean_squared_error\n",
      "mean_squared_log_error\n",
      "median_absolute_error\n"
     ]
    }
   ],
   "source": [
    "iris_data = load_iris()\n",
    "X, y = iris_data.data, iris_data.target\n",
    "\n",
    "metric_list = test_metrics(metrics_sklearn, config_sklearn_cla, X, y)\n",
    "\n",
    "print(\"target shape:\", y.shape)\n",
    "print(\"working metrics:\")\n",
    "for metric in metric_list:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "target shape: (150, 3)\n",
      "working metrics:\n",
      "accuracy_score\n",
      "average_precision_score\n",
      "explained_variance_score\n",
      "r2_score\n",
      "log_loss\n",
      "mean_absolute_error\n",
      "mean_squared_error\n",
      "mean_squared_log_error\n"
     ]
    }
   ],
   "source": [
    "X, y = iris_data.data, iris_data.target\n",
    "y = to_categorical(y, 3)\n",
    "\n",
    "metric_list = test_metrics(metrics_sklearn, config_sklearn_cla, X, y)\n",
    "\n",
    "print(\"target shape:\", y.shape)\n",
    "print(\"working metrics:\")\n",
    "for metric in metric_list:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Boston data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "target shape: (506,)\n",
      "working metrics:\n",
      "explained_variance_score\n",
      "r2_score\n",
      "max_error\n",
      "mean_absolute_error\n",
      "mean_squared_error\n",
      "mean_squared_log_error\n",
      "median_absolute_error\n"
     ]
    }
   ],
   "source": [
    "boston_data = load_boston()\n",
    "X, y = boston_data.data, boston_data.target\n",
    "\n",
    "metric_list = test_metrics(metrics_sklearn, config_sklearn_reg, X, y)\n",
    "\n",
    "print(\"target shape:\", y.shape)\n",
    "print(\"working metrics:\")\n",
    "for metric in metric_list:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Diabetes data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "target shape: (442,)\n",
      "working metrics:\n",
      "explained_variance_score\n",
      "r2_score\n",
      "max_error\n",
      "mean_absolute_error\n",
      "mean_squared_error\n",
      "mean_squared_log_error\n",
      "median_absolute_error\n"
     ]
    }
   ],
   "source": [
    "diabetes_data = load_diabetes()\n",
    "X, y = diabetes_data.data, diabetes_data.target\n",
    "\n",
    "metric_list = test_metrics(metrics_sklearn, config_sklearn_reg, X, y)\n",
    "\n",
    "print(\"target shape:\", y.shape)\n",
    "print(\"working metrics:\")\n",
    "for metric in metric_list:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Digits data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "target shape: (1797,)\n",
      "working metrics:\n",
      "accuracy_score\n",
      "balanced_accuracy_score\n",
      "explained_variance_score\n",
      "r2_score\n",
      "max_error\n",
      "mean_absolute_error\n",
      "mean_squared_error\n",
      "mean_squared_log_error\n",
      "median_absolute_error\n"
     ]
    }
   ],
   "source": [
    "digits_data = load_digits()\n",
    "X, y = digits_data.data, digits_data.target\n",
    "\n",
    "metric_list = test_metrics(metrics_sklearn, config_sklearn_cla, X, y)\n",
    "\n",
    "print(\"target shape:\", y.shape)\n",
    "print(\"working metrics:\")\n",
    "for metric in metric_list:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linnerud data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "target shape: (20, 3)\n",
      "working metrics:\n",
      "explained_variance_score\n",
      "r2_score\n",
      "mean_absolute_error\n",
      "mean_squared_error\n",
      "mean_squared_log_error\n"
     ]
    }
   ],
   "source": [
    "linnerud_data = load_linnerud()\n",
    "X, y = linnerud_data.data, linnerud_data.target\n",
    "\n",
    "metric_list = test_metrics(metrics_sklearn, config_sklearn_reg, X, y)\n",
    "\n",
    "print(\"target shape:\", y.shape)\n",
    "print(\"working metrics:\")\n",
    "for metric in metric_list:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Wine data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "target shape: (178,)\n",
      "working metrics:\n",
      "accuracy_score\n",
      "balanced_accuracy_score\n",
      "explained_variance_score\n",
      "r2_score\n",
      "max_error\n",
      "mean_absolute_error\n",
      "mean_squared_error\n",
      "mean_squared_log_error\n",
      "median_absolute_error\n"
     ]
    }
   ],
   "source": [
    "wine_data = load_wine()\n",
    "X, y = wine_data.data, wine_data.target\n",
    "\n",
    "metric_list = test_metrics(metrics_sklearn, config_sklearn_cla, X, y)\n",
    "\n",
    "print(\"target shape:\", y.shape)\n",
    "print(\"working metrics:\")\n",
    "for metric in metric_list:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cancer data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', description='Search 0', max=1, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "target shape: (569,)\n",
      "working metrics:\n",
      "accuracy_score\n",
      "balanced_accuracy_score\n",
      "average_precision_score\n",
      "f1_score\n",
      "precision_score\n",
      "recall_score\n",
      "jaccard_score\n",
      "roc_auc_score\n",
      "explained_variance_score\n",
      "r2_score\n",
      "brier_score_loss\n",
      "log_loss\n",
      "max_error\n",
      "mean_absolute_error\n",
      "mean_squared_error\n",
      "mean_squared_log_error\n",
      "median_absolute_error\n"
     ]
    }
   ],
   "source": [
    "cancer_data = load_breast_cancer()\n",
    "X, y = cancer_data.data, cancer_data.target\n",
    "\n",
    "metric_list = test_metrics(metrics_sklearn, config_sklearn_cla, X, y)\n",
    "\n",
    "print(\"target shape:\", y.shape)\n",
    "print(\"working metrics:\")\n",
    "for metric in metric_list:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1\n",
      "100/100 [==============================] - 8s 76ms/step - loss: 2.2778 - acc: 0.1700\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 7s 74ms/step - loss: 2.2778 - binary_accuracy: 0.9000\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 8s 79ms/step - loss: 2.2772 - categorical_accuracy: 0.1600\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 8s 78ms/step - loss: 2.2778 - top_k_categorical_accuracy: 0.5700\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 8s 78ms/step - loss: 2.2778 - mean_squared_error: 0.0895\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 8s 80ms/step - loss: 2.2772 - mean_absolute_error: 0.1794\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 8s 81ms/step - loss: 2.2778 - mean_absolute_percentage_error: 89707559.0400\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 8s 81ms/step - loss: 2.2778 - mean_squared_logarithmic_error: 0.0436\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 8s 84ms/step - loss: 2.2778 - squared_hinge: 0.9805\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 8s 84ms/step - loss: 2.2778 - hinge: 0.9897\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 9s 85ms/step - loss: 2.2778 - logcosh: 0.0403\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 9s 90ms/step - loss: 2.2778 - categorical_crossentropy: 2.2778\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 9s 89ms/step - loss: 2.2778 - binary_crossentropy: 0.3223\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 9s 89ms/step - loss: 2.2778 - kullback_leibler_divergence: 2.2778\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 9s 93ms/step - loss: 2.2778 - poisson: 0.3278\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 9s 91ms/step - loss: 2.2778 - cosine_proximity: -0.3241\n",
      "target shape: (100, 10)\n",
      "working metrics:\n",
      "accuracy\n",
      "binary_accuracy\n",
      "categorical_accuracy\n",
      "top_k_categorical_accuracy\n",
      "mean_squared_error\n",
      "mean_absolute_error\n",
      "mean_absolute_percentage_error\n",
      "mean_squared_logarithmic_error\n",
      "squared_hinge\n",
      "hinge\n",
      "logcosh\n",
      "categorical_crossentropy\n",
      "binary_crossentropy\n",
      "kullback_leibler_divergence\n",
      "poisson\n",
      "cosine_proximity\n"
     ]
    }
   ],
   "source": [
    "(X, y), (x_test, y_test) = mnist.load_data()\n",
    "X = X[0:100]\n",
    "y = y[0:100]\n",
    "\n",
    "X = X.reshape(100, 28, 28, 1)\n",
    "y = to_categorical(y)\n",
    "\n",
    "metric_list = test_metrics(metrics_keras, config_keras_cla, X, y)\n",
    "\n",
    "print(\"target shape:\", y.shape)\n",
    "print(\"working metrics:\")\n",
    "for metric in metric_list:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# cifar10 data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1\n",
      "100/100 [==============================] - 9s 93ms/step - loss: 2.3081 - acc: 0.0900\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 9s 93ms/step - loss: 2.3081 - binary_accuracy: 0.9000\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 10s 96ms/step - loss: 2.3081 - categorical_accuracy: 0.0900\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 10s 97ms/step - loss: 2.3081 - top_k_categorical_accuracy: 0.5600\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 10s 98ms/step - loss: 2.3081 - mean_squared_error: 0.0901\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 10s 99ms/step - loss: 2.3081 - mean_absolute_error: 0.1800\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 11s 108ms/step - loss: 2.2984 - mean_absolute_percentage_error: 89894737.2800\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 10s 101ms/step - loss: 2.3081 - mean_squared_logarithmic_error: 0.0440\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 10s 103ms/step - loss: 2.3081 - squared_hinge: 0.9810\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 10s 104ms/step - loss: 2.3081 - hinge: 0.9900\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 11s 105ms/step - loss: 2.3081 - logcosh: 0.0405\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 11s 106ms/step - loss: 2.2982 - categorical_crossentropy: 2.2982\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 11s 108ms/step - loss: 2.3081 - binary_crossentropy: 0.3257\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 11s 107ms/step - loss: 2.3081 - kullback_leibler_divergence: 2.3081\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 11s 107ms/step - loss: 2.3081 - poisson: 0.3308\n",
      "Epoch 1/1\n",
      "100/100 [==============================] - 11s 108ms/step - loss: 2.3081 - cosine_proximity: -0.3143\n",
      "target shape: (100, 10)\n",
      "working metrics:\n",
      "accuracy\n",
      "binary_accuracy\n",
      "categorical_accuracy\n",
      "top_k_categorical_accuracy\n",
      "mean_squared_error\n",
      "mean_absolute_error\n",
      "mean_absolute_percentage_error\n",
      "mean_squared_logarithmic_error\n",
      "squared_hinge\n",
      "hinge\n",
      "logcosh\n",
      "categorical_crossentropy\n",
      "binary_crossentropy\n",
      "kullback_leibler_divergence\n",
      "poisson\n",
      "cosine_proximity\n"
     ]
    }
   ],
   "source": [
    "(X, y), (x_test, y_test) = cifar10.load_data()\n",
    "\n",
    "X = X[0:100]\n",
    "y = y[0:100]\n",
    "\n",
    "X = X.reshape(100, 32, 32, 3)\n",
    "y = to_categorical(y)\n",
    "\n",
    "metric_list = test_metrics(metrics_keras, config_keras_cla, X, y)\n",
    "\n",
    "print(\"target shape:\", y.shape)\n",
    "print(\"working metrics:\")\n",
    "for metric in metric_list:\n",
    "    print(metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
